{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) MONAI Consortium  \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    "you may not use this file except in compliance with the License.  \n",
    "You may obtain a copy of the License at  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  \n",
    "Unless required by applicable law or agreed to in writing, software  \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    "See the License for the specific language governing permissions and  \n",
    "limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "You can download and run this notebook locally, or you can run it for free in a cloud environment using Colab or Sagemaker Studio Lab:\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/modules/tcia_dataset.ipynb)\n",
    "\n",
    "[![Open In SageMaker Studio Lab](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github.com/Project-MONAI/tutorials/blob/main/modules/tcia_dataset.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zKNxVGTHBTHt"
   },
   "source": [
    "# Quick Start With TciaDataset\n",
    "\n",
    "[The Cancer Imaging Archive (TCIA)](https://www.cancerimagingarchive.net/) is a service which de-identifies and hosts a large publicly available archive of medical images of cancer.  TCIA is funded by the [Cancer Imaging Program (CIP)](https://imaging.cancer.gov/), a part of the United States [National Cancer Institute (NCI)](https://www.cancer.gov/), and is managed by the [Frederick National Laboratory for Cancer Research (FNLCR)](https://frederick.cancer.gov/).\n",
    "\n",
    "In this tutorial, we will introduce how to use `TciaDataset` to automatically download and extract the TCIA datasets with accompanying DICOM segmentations, and act as PyTorch datasets to generate training/validation/test data.\n",
    "\n",
    "<p align=\"center\">\n",
    "  <img src=\"../figures/TCIA_LOGO.png\" alt=\"tcia logo\">\n",
    "</p>\n",
    "\n",
    "We'll cover the following topics in this tutorial:\n",
    "* Access a dataset(collection) from TCIA with TciaDataset\n",
    "* Visualize images and masks from the dataset\n",
    "* Create training experiment with TciaDataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GTBal5mgBTHu"
   },
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ESh_cJ5WBTHu"
   },
   "outputs": [],
   "source": [
    "!python -c \"import monai\" || pip install -q \"monai-weekly[nibabel, pillow, ignite, tqdm, pydicom]\"\n",
    "!python -c \"import matplotlib\" || pip install -q matplotlib"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "fnIg3CdPBTHv",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MONAI version: 1.1.0+11.g7de6c336.dirty\n",
      "Numpy version: 1.22.2\n",
      "Pytorch version: 1.13.0a0+d0d6b1f\n",
      "MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n",
      "MONAI rev id: 7de6c33656a99087ca3b89a817b0879cf093febc\n",
      "MONAI __file__: /workspace/Code/MONAI/monai/__init__.py\n",
      "\n",
      "Optional dependencies:\n",
      "Pytorch Ignite version: 0.4.10\n",
      "Nibabel version: 4.0.2\n",
      "scikit-image version: 0.19.3\n",
      "Pillow version: 9.0.1\n",
      "Tensorboard version: 2.11.0\n",
      "gdown version: 4.6.0\n",
      "TorchVision version: 0.14.0a0\n",
      "tqdm version: 4.64.1\n",
      "lmdb version: 1.3.0\n",
      "psutil version: 5.9.2\n",
      "pandas version: 1.4.4\n",
      "einops version: 0.6.0\n",
      "transformers version: 4.21.3\n",
      "mlflow version: 2.0.1\n",
      "pynrrd version: 1.0.0\n",
      "\n",
      "For details about installing the optional dependencies, please visit:\n",
      "    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from monai.transforms import (\n",
    "    Activations,\n",
    "    AsDiscrete,\n",
    "    EnsureChannelFirstd,\n",
    "    Compose,\n",
    "    LoadImaged,\n",
    "    Orientationd,\n",
    "    RandCropByPosNegLabeld,\n",
    "    RandScaleIntensityd,\n",
    "    RandShiftIntensityd,\n",
    "    RandFlipd,\n",
    "    Spacingd,\n",
    "    NormalizeIntensityd,\n",
    "    ResampleToMatchd,\n",
    ")\n",
    "import monai\n",
    "from monai.config import print_config\n",
    "from monai.networks.nets import UNet\n",
    "from monai.metrics import DiceMetric\n",
    "from monai.losses import DiceCELoss\n",
    "from monai.inferers import sliding_window_inference\n",
    "from monai.apps import TciaDataset\n",
    "from monai.apps.tcia import TCIA_LABEL_DICT\n",
    "from monai.data import DataLoader, decollate_batch\n",
    "import os\n",
    "import requests\n",
    "import tempfile\n",
    "import shutil\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6mxYYFvPBTHw"
   },
   "source": [
    "## Setup data directory\n",
    "\n",
    "You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  \n",
    "This allows you to save results and reuse downloads.  \n",
    "If not specified a temporary directory will be used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "wbTvWiEaBTHw",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/workspace/Data\n"
     ]
    }
   ],
   "source": [
    "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
    "if directory is not None:\n",
    "    os.makedirs(directory, exist_ok=True)\n",
    "root_dir = tempfile.mkdtemp() if directory is None else directory\n",
    "print(root_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Y6mNIzYOBTHw"
   },
   "source": [
    "## Brief Introduction\n",
    "\n",
    "Documentation about TCIA's REST APIs can be found at https://wiki.cancerimagingarchive.net/x/NIIiAQ. A full list of TCIA's Collections can be found at https://www.cancerimagingarchive.net/collections/. This is the best place to learn about the available datasets and discover supporting data including clinical spreadsheets, image features and other non-DICOM data that aren't available via the API (e.g NIfTI segmenation labels). However, this tutorial is focused on the use of DICOM segmentations.  You can use the following code to find out which collections have this kind of data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "pAGY_i7SktMB"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ACRIN-6698 has 385 patients with SEG modality\n",
      "Breast-MRI-NACT-Pilot has 64 patients with SEG modality\n",
      "C4KC-KiTS has 210 patients with SEG modality\n",
      "Colorectal-Liver-Metastases has 197 patients with SEG modality\n",
      "DRO-Toolkit has 32 patients with SEG modality\n",
      "Duke-Breast-Cancer-MRI has 127 patients with SEG modality\n",
      "HCC-TACE-Seg has 105 patients with SEG modality\n",
      "ISPY1 has 207 patients with SEG modality\n",
      "ISPY2 has 719 patients with SEG modality\n",
      "LIDC-IDRI has 875 patients with SEG modality\n",
      "Lung Phantom has 1 patients with SEG modality\n",
      "Lung-Fused-CT-Pathology has 6 patients with SEG modality\n",
      "NSCLC Radiogenomics has 144 patients with SEG modality\n",
      "NSCLC-Radiomics has 421 patients with SEG modality\n",
      "NSCLC-Radiomics-Interobserver1 has 21 patients with SEG modality\n",
      "PROSTATEx has 98 patients with SEG modality\n",
      "QIBA CT-1C has 1 patients with SEG modality\n",
      "QIN LUNG CT has 10 patients with SEG modality\n",
      "QIN-PROSTATE-Repeatability has 15 patients with SEG modality\n",
      "RIDER Lung CT has 31 patients with SEG modality\n",
      "The following collections have no patients with SEG modality: ['4D-Lung', 'ACRIN-Contralateral-Breast-MR', 'ACRIN-FLT-Breast', 'ACRIN-NSCLC-FDG-PET', 'Anti-PD-1_Lung', 'B-mode-and-CEUS-Liver', 'BREAST-DIAGNOSIS', 'Breast-Cancer-Screening-DBT', 'CBIS-DDSM', 'CC-Radiomics-Phantom', 'CC-Radiomics-Phantom-2', 'CC-Radiomics-Phantom-3', 'CMB-CRC', 'CMB-GEC', 'CMB-LCA', 'CMB-MEL', 'CMB-MML', 'CMB-PCA', 'CMMD', 'COVID-19-AR', 'COVID-19-NY-SBU', 'CPTAC-CCRCC', 'CPTAC-CM', 'CPTAC-LSCC', 'CPTAC-LUAD', 'CPTAC-PDA', 'CPTAC-SAR', 'CPTAC-UCEC', 'CT COLONOGRAPHY', 'CT Lymph Nodes', 'CT-vs-PET-Ventilation-Imaging', 'CTpred-Sunitinib-panNET', 'GBM-DSC-MRI-DRO', 'ICDC-Glioma', 'LCTSC', 'Lung-PET-CT-Dx', 'LungCT-Diagnosis', 'MIDRC-RICORD-1A', 'MIDRC-RICORD-1B', 'MIDRC-RICORD-1C', 'Mouse-Astrocytoma', 'Mouse-Mammary', 'NSCLC-Radiomics-Genomics', 'NaF PROSTATE', 'PDMR-292921-168-R', 'PDMR-425362-245-T', 'PDMR-521955-158-R4', 'PDMR-833975-119-R', 'PDMR-997537-175-T', 'PDMR-BL0293-F563', 'PROSTATE-DIAGNOSIS', 'PROSTATE-MRI', 'Pancreas-CT', 'Pancreatic-CT-CBCT-SEG', 'Pediatric-CT-SEG', 'Pelvic-Reference-Data', 'Phantom FDA', 'Prostate Fused-MRI-Pathology', 'Prostate-3T', 'Prostate-MRI-US-Biopsy', 'Pseudo-PHI-DICOM-Data', 'QIBA-CT-Liver-Phantom', 'QIN Breast DCE-MRI', 'QIN PET Phantom', 'QIN-BREAST', 'RIDER Breast MRI', 'RIDER Lung PET-CT', 'RIDER PHANTOM MRI', 'RIDER PHANTOM PET-CT', 'SPIE-AAPM Lung CT Challenge', 'Soft-tissue-Sarcoma', 'StageII-Colorectal-CT', 'TCGA-BLCA', 'TCGA-BRCA', 'TCGA-CESC', 'TCGA-COAD', 'TCGA-ESCA', 'TCGA-KICH', 'TCGA-KIRC', 'TCGA-KIRP', 'TCGA-LIHC', 'TCGA-LUAD', 'TCGA-LUSC', 'TCGA-OV', 'TCGA-PRAD', 'TCGA-READ', 'TCGA-SARC', 'TCGA-STAD', 'TCGA-THCA', 'TCGA-UCEC', 'UPENN-GBM', 'VICTRE', 'Vestibular-Schwannoma-SEG']\n"
     ]
    }
   ],
   "source": [
    "# get patient counts Collections that have SEG modality\n",
    "\n",
    "base_url = \"https://services.cancerimagingarchive.net/nbia-api/services/v1/\"\n",
    "modality = \"SEG\"\n",
    "\n",
    "collection_url = base_url + \"getCollectionValues\"\n",
    "collection_data = requests.get(collection_url).json()\n",
    "\n",
    "not_found = []\n",
    "for x in collection_data:\n",
    "    collection_name = x[\"Collection\"]\n",
    "    patient_url = base_url + \"getPatientByCollectionAndModality?Collection=\" + collection_name + \"&Modality=\" + modality\n",
    "    patients = requests.get(patient_url)\n",
    "    if patients.text != \"\":\n",
    "        patients = patients.json()\n",
    "        print(collection_name, \"has\", len(patients), \"patients with\", modality, \"modality\")\n",
    "    else:\n",
    "        not_found.append(collection_name)\n",
    "\n",
    "print(\"The following collections have no patients with\", modality, \"modality:\", not_found)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you'd like to learn more about a Collection you discovered in the last step, you can look them up on https://www.cancerimagingarchive.net/collections/ or modify the following cell to obtain some basic information about it via the API. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "QIN-PROSTATE-Repeatability has 15 patients, {'MR', 'SR', 'SEG'} modalities, and {'Not Specified', 'PROSTATE'} anatomic entities\n"
     ]
    }
   ],
   "source": [
    "# set collection of interest\n",
    "collection_name = \"QIN-PROSTATE-Repeatability\"\n",
    "\n",
    "# look up number of patients, modalities and body parts examined\n",
    "patient_url = base_url + \"getPatient?Collection=\" + collection_name\n",
    "patients = requests.get(patient_url)\n",
    "if patients.text != \"\":\n",
    "    patients = patients.json()\n",
    "    clean_patient_ids = {item[\"PatientId\"] for item in patients}\n",
    "    modality_url = base_url + \"getModalityValues?Collection=\" + collection_name\n",
    "    modalities = requests.get(modality_url).json()\n",
    "    clean_modalities = {item[\"Modality\"] for item in modalities}\n",
    "    bodypart_url = base_url + \"getBodyPartValues?Collection=\" + collection_name\n",
    "    bodyparts = requests.get(bodypart_url).json()\n",
    "    clean_bodyparts = set()\n",
    "    for item in bodyparts:\n",
    "        if len(item):\n",
    "            clean_bodyparts.add(item[\"BodyPartExamined\"])\n",
    "        else:\n",
    "            clean_bodyparts.add(\"Not Specified\")\n",
    "    print(\n",
    "        collection_name,\n",
    "        \"has\",\n",
    "        len(clean_patient_ids),\n",
    "        \"patients,\",\n",
    "        clean_modalities,\n",
    "        \"modalities, and\",\n",
    "        clean_bodyparts,\n",
    "        \"anatomic entities\",\n",
    "    )\n",
    "else:\n",
    "    print(\"Collection not found.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "B-1NQ7UnBTHw"
   },
   "source": [
    "## Create TciaDataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7rjH5zU7BTHw"
   },
   "source": [
    "For a collection that contains image with type `SEG`, `TciaDataset` (inherits from MONAI `CacheDataset`) can be utilized to download the dataset and load images. We hope to add support for RTSTRUCT at a later date.  \n",
    "\n",
    "`TciaDataset` includes some of the NBIA APIs, and provides rich parameters to achieve expected behavior. Some of the main parameters include:\n",
    "1. **root_dir**: target directory to download and load TCIA dataset.\n",
    "1. **collection**: name of collection (dataset)\n",
    "1. **download**: whether to download and extract the dataset first.\n",
    "1. **download_len**: number of series that will be downloaded. Sometimes only download a few series is enough. If not specified, the whole dataset will be downloaded.\n",
    "1. **seg_type**: modality type of segmentation that is used to do the first step download (two steps download will be introduced below)\n",
    "1. **modality_tag**: tag of modality.\n",
    "1. **ref_series_uid_tag**: tag of referenced Series Instance UID.\n",
    "1. **ref_sop_uid_tag**: tag of referenced SOP Instance UID.\n",
    "\n",
    "After downloading (if specified, all images and \"seg_type\" data will be stored locally for further usage.\n",
    "\n",
    "<p align=\"center\">\n",
    "  <img src=\"../figures/tcia_download_steps.png\" alt=\"tcia download steps\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "h5TFr6LUBTHx"
   },
   "outputs": [],
   "source": [
    "# Let's take the \"QIN-PROSTATE-Repeatability\" collection for example\n",
    "collection, seg_type = \"QIN-PROSTATE-Repeatability\", \"SEG\"\n",
    "\n",
    "# use the `LoadImaged` transform to load data, `label_dict` is used to unify the label order\n",
    "transform = Compose(\n",
    "    [\n",
    "        LoadImaged(reader=\"PydicomReader\", keys=[\"image\", \"seg\"], label_dict=TCIA_LABEL_DICT[collection]),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uWSKcQ1ABTHx",
    "tags": []
   },
   "outputs": [],
   "source": [
    "ds = TciaDataset(\n",
    "    root_dir=root_dir,\n",
    "    collection=collection,\n",
    "    section=\"training\",\n",
    "    download=True,\n",
    "    seg_type=seg_type,\n",
    "    progress=False,\n",
    "    cache_rate=0.0,\n",
    "    val_frac=0.2,\n",
    "    transform=transform,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YfFtFDuuBTHx"
   },
   "source": [
    "### Detect dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "v7SvgzqLBTHx"
   },
   "outputs": [],
   "source": [
    "sample = 0\n",
    "\n",
    "ds[sample][\"image\"].affine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tGlqDXVSBTHy"
   },
   "outputs": [],
   "source": [
    "ds[sample][\"seg\"].affine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "TfhX9310BTHy"
   },
   "outputs": [],
   "source": [
    "print(ds[sample][\"image\"].shape, ds[sample][\"seg\"].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cBbrZ6ltBTHy"
   },
   "source": [
    "As we can see from the metadata, the affine matrixes between images and segmentations have some differences: the direction of the third spatial dimension is opposite.\n",
    "\n",
    "We can use the `ResampleToMatchd` transform to unify them.\n",
    "In addition, images do not have the channel dimension, and the channel dimension of segmentations can be adjusted into the first dimension. We can use the `EnsureChannelFirstd` to address this issue"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ihIZCbKCBTHy"
   },
   "source": [
    "### Add pre-processing transforms and reload"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tEw0IE6iBTHy"
   },
   "outputs": [],
   "source": [
    "transform = Compose(\n",
    "    [\n",
    "        LoadImaged(reader=\"PydicomReader\", keys=[\"image\", \"seg\"], label_dict=TCIA_LABEL_DICT[collection]),\n",
    "        EnsureChannelFirstd(keys=[\"image\", \"seg\"]),\n",
    "        ResampleToMatchd(keys=\"image\", key_dst=\"seg\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "ds = TciaDataset(\n",
    "    root_dir=root_dir,\n",
    "    collection=collection,\n",
    "    section=\"training\",\n",
    "    download=False,\n",
    "    seg_type=seg_type,\n",
    "    progress=False,\n",
    "    cache_rate=0.0,\n",
    "    val_frac=0.2,\n",
    "    transform=transform,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vM17g0t7BTHy"
   },
   "outputs": [],
   "source": [
    "sample = 0\n",
    "\n",
    "ds[sample][\"image\"].affine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6tHnUdmGBTHy"
   },
   "outputs": [],
   "source": [
    "ds[sample][\"seg\"].affine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IzXeM06mBTHz"
   },
   "outputs": [],
   "source": [
    "print(ds[sample][\"image\"].shape, ds[sample][\"seg\"].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0XNwkVjqBTHz"
   },
   "source": [
    "As we can see, now the affine matrixes for images and segmentations are identical."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PuLmSufTBTHz"
   },
   "source": [
    "### Visualize and check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cfDKvR2WBTHz"
   },
   "outputs": [],
   "source": [
    "# from W, H, D to D, H, W\n",
    "img = ds[4][\"image\"].numpy().transpose([0, 3, 2, 1])\n",
    "seg = ds[4][\"seg\"].numpy().transpose([0, 3, 2, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Vq8S6CAaBTHz"
   },
   "outputs": [],
   "source": [
    "every_n = 10\n",
    "fig_seg = monai.visualize.matshow3d(img[0], show=True, every_n=every_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tt5tP1grBTHz"
   },
   "outputs": [],
   "source": [
    "fig_seg = monai.visualize.matshow3d(seg[3], show=True, every_n=every_n)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QiGE57NbBTHz"
   },
   "source": [
    "## Create a Training Pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YwxDP5WQBTHz"
   },
   "source": [
    "### Define CacheDataset and DataLoader for training and validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jSzfpzW1BTHz",
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "train_transforms = Compose(\n",
    "    [\n",
    "        LoadImaged(\n",
    "            reader=\"PydicomReader\", keys=[\"image\", \"seg\"], label_dict=TCIA_LABEL_DICT[collection], image_only=True\n",
    "        ),\n",
    "        EnsureChannelFirstd(keys=[\"image\", \"seg\"]),\n",
    "        ResampleToMatchd(keys=\"image\", key_dst=\"seg\"),\n",
    "        Orientationd(keys=[\"image\", \"seg\"], axcodes=\"RAS\"),\n",
    "        Spacingd(keys=[\"image\", \"seg\"], pixdim=(1.0, 1.0, 0.5), mode=(\"bilinear\", \"nearest\")),\n",
    "        NormalizeIntensityd(keys=\"image\", nonzero=True, channel_wise=True),\n",
    "        RandCropByPosNegLabeld(\n",
    "            keys=[\"image\", \"seg\"],\n",
    "            label_key=\"seg\",\n",
    "            spatial_size=(96, 96, 64),\n",
    "            pos=1,\n",
    "            neg=1,\n",
    "            num_samples=2,\n",
    "            image_key=\"image\",\n",
    "            image_threshold=0,\n",
    "        ),\n",
    "        RandFlipd(keys=(\"image\", \"seg\"), prob=0.5, spatial_axis=[0]),\n",
    "        RandFlipd(keys=(\"image\", \"seg\"), prob=0.5, spatial_axis=[1]),\n",
    "        RandScaleIntensityd(keys=\"image\", factors=(-0.2, 0.2), prob=0.5),\n",
    "        RandShiftIntensityd(keys=\"image\", offsets=(-0.1, 0.1), prob=0.5),\n",
    "    ]\n",
    ")\n",
    "\n",
    "val_transforms = Compose(\n",
    "    [\n",
    "        LoadImaged(\n",
    "            reader=\"PydicomReader\", keys=[\"image\", \"seg\"], label_dict=TCIA_LABEL_DICT[collection], image_only=True\n",
    "        ),\n",
    "        EnsureChannelFirstd(keys=[\"image\", \"seg\"]),\n",
    "        ResampleToMatchd(keys=\"image\", key_dst=\"seg\"),\n",
    "        Orientationd(keys=[\"image\", \"seg\"], axcodes=\"RAS\"),\n",
    "        Spacingd(keys=[\"image\", \"seg\"], pixdim=(1.0, 1.0, 0.5), mode=(\"bilinear\", \"nearest\")),\n",
    "        NormalizeIntensityd(keys=\"image\", nonzero=True, channel_wise=True),\n",
    "    ]\n",
    ")\n",
    "\n",
    "train_ds = TciaDataset(\n",
    "    root_dir=root_dir,\n",
    "    collection=collection,\n",
    "    section=\"training\",\n",
    "    download=False,\n",
    "    seg_type=seg_type,\n",
    "    progress=True,\n",
    "    cache_rate=1.0,\n",
    "    val_frac=0.2,\n",
    "    transform=train_transforms,\n",
    ")\n",
    "\n",
    "val_ds = TciaDataset(\n",
    "    root_dir=root_dir,\n",
    "    collection=collection,\n",
    "    section=\"training\",\n",
    "    download=False,\n",
    "    seg_type=seg_type,\n",
    "    progress=True,\n",
    "    cache_rate=1.0,\n",
    "    val_frac=0.2,\n",
    "    transform=val_transforms,\n",
    ")\n",
    "\n",
    "train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)\n",
    "val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ijsFlq3qBTH0"
   },
   "outputs": [],
   "source": [
    "# there are 4 labels, and for the example here, we will only focus on \"WholeGland\".\n",
    "\n",
    "print(train_ds[0][0][\"seg\"].meta[\"labels\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "b8WmN_qxBTH0"
   },
   "source": [
    "### Create Model, Loss, Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FNxlkHtqBTH0"
   },
   "outputs": [],
   "source": [
    "# standard PyTorch program style: create UNet, DiceLoss and Adam optimizer\n",
    "device = torch.device(\"cuda:0\")\n",
    "model = UNet(\n",
    "    spatial_dims=3,\n",
    "    in_channels=1,\n",
    "    out_channels=1,\n",
    "    channels=(16, 32, 64, 128, 256),\n",
    "    strides=(2, 2, 2, 2),\n",
    "    num_res_units=2,\n",
    "    norm=\"batch\",\n",
    ").to(device)\n",
    "loss_function = DiceCELoss(sigmoid=True, include_background=True, batch=True, squared_pred=True)\n",
    "optimizer = torch.optim.Adam(model.parameters(), 3e-4)\n",
    "max_epochs = 1000\n",
    "lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)\n",
    "dice_metric = DiceMetric(include_background=True, reduction=\"mean\")\n",
    "selected_label = 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FvILFYmbBTH0"
   },
   "source": [
    "### Execute a typical PyTorch training process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7RyTX__vBTH0"
   },
   "outputs": [],
   "source": [
    "val_interval = 2\n",
    "best_metric = -1\n",
    "best_metric_epoch = -1\n",
    "epoch_loss_values = []\n",
    "metric_values = []\n",
    "post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])\n",
    "\n",
    "for epoch in range(max_epochs):\n",
    "    print(\"-\" * 10)\n",
    "    print(f\"epoch {epoch + 1}/{max_epochs}\")\n",
    "    model.train()\n",
    "    epoch_loss = 0\n",
    "    step = 0\n",
    "    for batch_data in train_loader:\n",
    "        step += 1\n",
    "        inputs, labels = (\n",
    "            batch_data[\"image\"].to(device),\n",
    "            batch_data[\"seg\"][:, selected_label : selected_label + 1, ...].to(device),\n",
    "        )\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(inputs)\n",
    "        loss = loss_function(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        epoch_loss += loss.item()\n",
    "    lr_scheduler.step()\n",
    "    epoch_loss /= step\n",
    "    epoch_loss_values.append(epoch_loss)\n",
    "    print(f\"epoch {epoch + 1} average loss: {epoch_loss:.4f}\")\n",
    "\n",
    "    if (epoch + 1) % val_interval == 0:\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            for val_data in val_loader:\n",
    "                val_inputs, val_labels = (\n",
    "                    val_data[\"image\"].to(device),\n",
    "                    val_data[\"seg\"][:, selected_label : selected_label + 1, ...].to(device),\n",
    "                )\n",
    "                roi_size = (96, 96, 64)\n",
    "                sw_batch_size = 4\n",
    "                val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)\n",
    "                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]\n",
    "                # compute metric for current iteration\n",
    "                dice_metric(y_pred=val_outputs, y=val_labels)\n",
    "\n",
    "            # aggregate the final mean dice result\n",
    "            metric = dice_metric.aggregate().item()\n",
    "            # reset the status for next validation round\n",
    "            dice_metric.reset()\n",
    "\n",
    "            metric_values.append(metric)\n",
    "            if metric > best_metric:\n",
    "                best_metric = metric\n",
    "                best_metric_epoch = epoch + 1\n",
    "                torch.save(model.state_dict(), os.path.join(root_dir, \"best_metric_model.pth\"))\n",
    "                print(\"saved new best metric model\")\n",
    "            print(\n",
    "                f\"current epoch: {epoch + 1} current mean dice: {metric:.4f}\"\n",
    "                f\"\\nbest mean dice: {best_metric:.4f} \"\n",
    "                f\"at epoch: {best_metric_epoch}\"\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7O8oqoqJBTH0"
   },
   "outputs": [],
   "source": [
    "print(f\"train completed, best_metric: {best_metric:.4f} \" f\"at epoch: {best_metric_epoch}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hXZ_Z_PeBTH1"
   },
   "source": [
    "## Plot the loss and metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "F6YsgbWSBTH1"
   },
   "outputs": [],
   "source": [
    "plt.figure(\"train\", (12, 6))\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.title(\"Epoch Average Loss\")\n",
    "x = [i + 1 for i in range(len(epoch_loss_values))]\n",
    "y = epoch_loss_values\n",
    "plt.xlabel(\"epoch\")\n",
    "plt.plot(x, y)\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.title(\"Val Mean Dice\")\n",
    "x = [val_interval * (i + 1) for i in range(len(metric_values))]\n",
    "y = metric_values\n",
    "plt.xlabel(\"epoch\")\n",
    "plt.plot(x, y)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EfnbXpb4BTH1"
   },
   "source": [
    "## Check best model output with the input image and label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VokqAPXDBTH1"
   },
   "outputs": [],
   "source": [
    "model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\"), weights_only=True))\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    for i, val_data in enumerate(val_loader):\n",
    "        roi_size = (96, 96, 64)\n",
    "        sw_batch_size = 4\n",
    "        val_outputs = sliding_window_inference(val_data[\"image\"].to(device), roi_size, sw_batch_size, model)\n",
    "        depth = val_data[\"image\"].shape[-1]\n",
    "        slc = depth // 2\n",
    "        plt.figure(\"check\", (18, 6))\n",
    "        plt.subplot(1, 3, 1)\n",
    "        plt.title(f\"image {i}\")\n",
    "        plt.imshow(val_data[\"image\"][0, 0, :, :, slc], cmap=\"gray\")\n",
    "        plt.subplot(1, 3, 2)\n",
    "        plt.title(f\"label {i}\")\n",
    "        plt.imshow(val_data[\"seg\"][0, selected_label, :, :, slc])\n",
    "        plt.subplot(1, 3, 3)\n",
    "        plt.title(f\"output {i}\")\n",
    "        plt.imshow(post_pred(val_outputs[0]).detach().cpu()[0, :, :, slc])\n",
    "        plt.show()\n",
    "        if i == 2:\n",
    "            break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tNDPYjY7BTH2"
   },
   "source": [
    "## Cleanup data directory\n",
    "\n",
    "Remove directory if a temporary was used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cl9XjjOFBTH2"
   },
   "outputs": [],
   "source": [
    "if directory is None:\n",
    "    shutil.rmtree(root_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Fv7KGlp0BTH2"
   },
   "source": [
    "# Acknowledgments\n",
    "\n",
    "Users of the [QIN-Prostate-Repeatability](http://doi.org/10.7937/K9/TCIA.2018.MR1CKGND) dataset used in this notebook must abide by the [TCIA Data Usage Policy](https://wiki.cancerimagingarchive.net/x/c4hF) and the [Creative Commons Attribution 4.0 International License](https://creativecommons.org/licenses/by/4.0/) under which it has been published. Attribution should include references to the following citations:\n",
    "\n",
    "* Fedorov, A; Schwier, M; Clunie, D; Herz, C; Pieper, S; Kikinis, R; Tempany, C; Fennessy, F. (2018). Data From QIN-PROSTATE-Repeatability. The Cancer Imaging Archive. DOI: [10.7937/K9/TCIA.2018.MR1CKGND](http://doi.org/10.7937/K9/TCIA.2018.MR1CKGND)\n",
    "\n",
    "* Fedorov A, Vangel MG, Tempany CM, Fennessy FM. Multiparametric Magnetic Resonance Imaging of the Prostate: Repeatability of Volume and Apparent Diffusion Coefficient Quantification. Investigative Radiology. 52, 538–546 (2017). DOI: [10.1097/RLI.0000000000000382](https://doi.org/10.1097/RLI.0000000000000382)\n",
    "\n",
    "* Fedorov, A., Schwier, M., Clunie, D., Herz, C., Pieper, S., Kikinis,R., Tempany, C. & Fennessy, F. An annotated test-retest collection of prostate multiparametric MRI. Scientific Data 5, 180281 (2018). DOI: \n",
    "[10.1038/sdata.2018.281](https://doi.org/10.1038/sdata.2018.281)\n",
    "\n",
    "* Clark K, Vendt B, Smith K, Freymann J, Kirby J, Koppel P, Moore S, Phillips S, Maffitt D, Pringle M, Tarbox L, Prior F. The Cancer Imaging Archive (TCIA): Maintaining and Operating a Public Information Repository, Journal of Digital Imaging, Volume 26, Number 6, December, 2013, pp 1045-1057. DOI: [10.1007/s10278-013-9622-7](https://dx.doi.org/10.1007%2Fs10278-013-9622-7)\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "include_colab_link": true,
   "name": "tcia_dataset.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
